import torch
import torchvision
import torchvision.transforms as transforms

# Preprocessing shared by the train and test splits: convert each image to a
# tensor, then normalize all three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Training split; downloaded into ./data if it is not already there.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Mini-batches of 4 images, reshuffled every epoch, loaded by 2 workers.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
)

# Held-out test split, preprocessed the same way as the training split.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Test batches keep the dataset order (no shuffling).
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=4,
    shuffle=False,
    num_workers=2,
)

# Human-readable class names, indexed by label id.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)


# Each input image is 3 x 32 x 32 (channels x height x width).
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Small convolutional classifier with two conv layers and three linear layers.

    The flatten step in ``forward`` hard-codes ``16 * 5 * 5`` features per
    sample, which is what two (5x5 conv -> 2x2 max-pool) stages produce for a
    3 x 32 x 32 input (32 -> 28 -> 14 -> 10 -> 5).
    """

    def __init__(self):
        super(Net, self).__init__()
        # 3 input channels -> 6 feature maps, 5x5 kernel.
        self.conv1 = nn.Conv2d(3, 6, 5)
        # 2x2 max-pooling with stride 2; reused after both conv layers.
        self.pool = nn.MaxPool2d(2, 2)
        # 6 -> 16 feature maps, 5x5 kernel.
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Classifier head: 400 -> 120 -> 84 -> 10 class scores.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return raw (un-normalized) class scores for a batch of images.

        Args:
            x: image batch; expected to be (N, 3, 32, 32) so that the
                flattened feature size matches ``fc1``.

        Returns:
            Tensor of shape (N, 10) with one score per class.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten the feature maps for the fully connected layers.
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        # The original fell off the end here and implicitly returned None.
        return x

# The model instance that the training and evaluation code operates on.
net = Net()

# Loss function and optimizer.
from torch import optim

# Cross-entropy loss applied to the network outputs and the target labels.
criterion = nn.CrossEntropyLoss()
# SGD with momentum over every parameter of the network.
optimizer = optim.SGD(
    net.parameters(),
    lr=0.001,
    momentum=0.9,
)



def _train(model, loader, criterion, optimizer, epochs=5, log_every=2000):
    """Train ``model`` in place, printing the mean loss every ``log_every`` batches.

    Args:
        model: network being optimized; called as ``model(inputs)``.
        loader: iterable yielding ``(inputs, labels)`` batches.
        criterion: loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: optimizer that owns ``model``'s parameters.
        epochs: number of full passes over ``loader``.
        log_every: number of batches between progress reports.
    """
    for epoch in range(epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(loader):
            # Gradients accumulate across backward() calls, so clear them first.
            optimizer.zero_grad()

            # Forward pass, loss, backward pass, parameter update.
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # .item() extracts the Python float; indexing a 0-dim loss tensor
            # (the old ``loss.data[0]``) raises on current PyTorch.
            running_loss += loss.item()
            if i % log_every == log_every - 1:
                print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / log_every))
                running_loss = 0.0


def _evaluate(model, loader):
    """Count correct top-1 predictions of ``model`` over ``loader``.

    Args:
        model: network to evaluate; called as ``model(images)``.
        loader: iterable yielding ``(images, labels)`` batches.

    Returns:
        ``(correct, total)`` as plain Python ints.
    """
    correct = 0
    total = 0
    # No gradients are needed for evaluation; skip building the autograd graph.
    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images)
            # torch.max over dim 1 returns (values, indices); the indices are
            # the predicted class ids.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct, total


if __name__ == '__main__':
    # Train the network.
    _train(net, trainloader, criterion, optimizer, epochs=5)
    print("Finished Training")

    print("Begin Testing")
    correct, total = _evaluate(net, testloader)

    print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))